#!/usr/bin/env python
# coding: utf-8

# In[14]:


from datetime import datetime, timedelta
from netCDF4 import Dataset
import matplotlib
import matplotlib.pyplot as plt
import os
import numpy as np
import cmocean.cm as cmo
from matplotlib import ticker 
import glob 
import matplotlib as mpl
from wrf import getvar, interplevel, get_cartopy, vertcross, vinterp
from matplotlib.colors import LogNorm
import sys
from string import ascii_lowercase
from scipy import signal
import matplotlib.colors as mcolors

# Global matplotlib styling shared by every figure produced by this script.
# Applied in a single update call; each key is validated by matplotlib exactly
# as it would be for an individual assignment.
mpl.rcParams.update({
    # figure-level defaults
    'figure.figsize': [10,10],
    'figure.titlesize': 10,
    'figure.titleweight': 'bold',
    # tick / axis text sizes
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'axes.labelsize': 11,
    'axes.titlesize': 10,
    # line weights
    'lines.linewidth': 1.8,
    'grid.linewidth': .25,
    # tight spacing between subplot panels
    'figure.subplot.wspace': 0.05,
    'figure.subplot.hspace': 0.05,
    # legend appearance
    'legend.fontsize': 11,
    'legend.framealpha': .75,
    'legend.loc': 'best',
    # output defaults used by savefig
    'savefig.bbox': 'tight',
    'savefig.dpi': 600,
})

def shiftedColorMap(cmap, start=0, midpoint=0.5, stop=1.0, name='shiftedcmap'):
    '''
    Function to offset the "center" of a colormap. Useful for
    data with a negative min and positive max and you want the
    middle of the colormap's dynamic range to be at zero.

    The new colormap is also registered with matplotlib under `name`;
    calling the function again with the same `name` replaces the
    previously registered map.

    Input
    -----
      cmap : The matplotlib colormap to be altered
      start : Offset from lowest point in the colormap's range.
          Defaults to 0.0 (no lower offset). Should be between
          0.0 and `midpoint`.
      midpoint : The new center of the colormap. Defaults to 
          0.5 (no shift). Should be between 0.0 and 1.0. In
          general, this should be  1 - vmax / (vmax + abs(vmin))
          For example if your data range from -15.0 to +5.0 and
          you want the center of the colormap at 0.0, `midpoint`
          should be set to  1 - 5/(5 + 15)) or 0.75
      stop : Offset from highest point in the colormap's range.
          Defaults to 1.0 (no upper offset). Should be between
          `midpoint` and 1.0.
      name : Name given to the new colormap and used to register it.

    Output
    ------
      The new matplotlib.colors.LinearSegmentedColormap.
    '''
    # regular index to compute the colors: 257 evenly spaced samples of the
    # source colormap between `start` and `stop`
    reg_index = np.linspace(start, stop, 257)

    # shifted index to match the data: 128 anchor positions below `midpoint`
    # and 129 from `midpoint` up to 1.0, so sample 128 lands on `midpoint`
    shift_index = np.hstack([
        np.linspace(0.0, midpoint, 128, endpoint=False),
        np.linspace(midpoint, 1.0, 129, endpoint=True)
    ])

    cdict = {'red': [], 'green': [], 'blue': [], 'alpha': []}
    channels = ('red', 'green', 'blue', 'alpha')

    for ri, si in zip(reg_index, shift_index):
        # cmap(ri) yields (r, g, b, a); each becomes an (x, y0, y1) anchor
        for channel, value in zip(channels, cmap(ri)):
            cdict[channel].append((si, value, value))

    newcmap = matplotlib.colors.LinearSegmentedColormap(name, cdict)

    # plt.register_cmap was deprecated in Matplotlib 3.7 and removed in 3.9.
    # Prefer the colormap registry when it exists and fall back to the legacy
    # call on old installations. force=True lets a repeated call with the same
    # name overwrite the earlier user-registered map instead of raising.
    registry = getattr(matplotlib, 'colormaps', None)
    if registry is not None and hasattr(registry, 'register'):
        registry.register(newcmap, force=True)
    else:
        plt.register_cmap(cmap=newcmap)

    return newcmap

# Input directories for the 1-minute WRF output of the run analysed below.
# Files with the same 'wrfout_d02_<time>' name are opened from each directory:
#   rdir  -> winds ('ua', 'wa'), geopotential ('PH', 'PHB'), 'HGT', 'T', lat/lon
#   rdir2 -> 'TKE' and the 'm11', 'm22', 'm33' fields
#   rdir3 -> fire-grid fields 'FIRE_AREA' and 'FXLONG'
rdir = '/glade/scratch/tjuliano/fire/marshall/joinwrfh/1min/big2_maria_run9_add_spotting/met/'
rdir2 = '/glade/scratch/tjuliano/fire/marshall/joinwrfh/1min/big2_maria_run9_add_spotting/turb/'
rdir3 = '/glade/scratch/tjuliano/fire/marshall/joinwrfh/1min/big2_maria_run9_add_spotting/fire/'
#rdir = '/glade/scratch/tjuliano/fire/sagehen/WRF-Fire-LEAPHI/WRF/test/marshall/2021123012/run/netcdf_2/'
# NOTE(review): dx is not referenced anywhere else in this file; presumably the
# horizontal grid spacing in km (111.111 m) -- confirm before relying on it.
dx = 111.111/1000


# ## Cross-section (xs) panels of total TKE with potential-temperature contours (no wind vectors)

# In[ ]:



# ---------------------------------------------------------------------------
# Configuration for the cross-section panels built by the loop below.
# ---------------------------------------------------------------------------

# Lower-case letters 'a'..'z'.
# NOTE(review): not referenced elsewhere in this file; presumably panel labels.
lowerc = list(ascii_lowercase)

# NOTE(review): output sub-directory; not referenced by the savefig call in
# this file -- confirm whether figures are meant to be written here.
sdir = 'xs/for_paper/xsect_u/'

# Filled by the loop with the (level, row, column) index of the maximum zonal
# wind at each analysis time.
loc = []

# Shift, in grid rows, added to the base row index of the cross section.
offset = -10

# Width in minutes of the window centred on each analysis time; one 1-minute
# output file is read for every minute from -dt/2 to +dt/2.
dt = 40

# Analysis times, one entry per panel: day of December 2021, hour and minute.
dd = [30,30,30,30,31]
hh = [19,19,22,23,2]
mm = [0,45,5,0,30]

# Longitude marked with a dotted vertical line in each panel.
# NOTE(review): physical meaning of these longitudes is not stated in the file.
xlon_fr = [-105.02025,-105.008514,-105.17154,-105.19501,-105.18457]

# Target heights (m) for the vertical interpolation: 5, 10, ..., 6000.
zinterp = np.arange(5,6005,5)

# The loop indexes dd, hh, mm and xlon_fr with the same counter; catch a
# mismatch here instead of after the expensive file processing has started.
if not (len(dd) == len(hh) == len(mm) == len(xlon_fr)):
    raise ValueError('dd, hh, mm and xlon_fr must all have the same length')
# ---------------------------------------------------------------------------
# Build one west-east cross-section panel of total TKE per analysis time.
#
# For every analysis time the 1-minute files from -dt/2 to +dt/2 minutes are
# read, the winds along one grid row are interpolated to the fixed heights in
# `zinterp`, and the resulting time series is linearly detrended.  The
# perturbation at the central (zero-offset) minute gives the resolved TKE,
# which is added to the sub-grid fields read at that same minute.
#
# Changes relative to the earlier version of this loop:
#   * the meridional wind is read from 'va' (it was read from 'ua', so the
#     resolved TKE counted u' twice and never included v');
#   * a leftover debug print / sys.exit() that stopped the script after the
#     first file has been removed, so the panels are actually produced;
#   * the central sample is located where the minute offset is zero instead of
#     ceil(len/2), which pointed one minute past the centre.
# ---------------------------------------------------------------------------

# South-north grid row of the cross section (first index of the 2-D fields).
# It does not depend on time, so it is computed once.
x = 262 + 15 + offset
# Two adjacent rows are passed to interplevel; only the first one is kept.
rows = slice(x, x + 2)

# Minute offsets relative to the analysis time, and the position of offset 0.
time_arr = np.arange(-int(dt/2),int(dt/2)+1,1)
midpt = int(np.flatnonzero(time_arr == 0)[0])

for tidx in np.arange(len(dd)):
    time_og = datetime(2021, 12, dd[tidx], hh[tidx], mm[tidx], 0)
    print ('Doing ', time_og)

    for ww in np.arange(len(time_arr)):
        # True only for the file valid exactly at the analysis time.
        is_mid = time_arr[ww] == 0

        time = time_og + timedelta(minutes=int(time_arr[ww]))
        if is_mid:
            time_tit = time  # timestamp shown in the panel title
        filename = 'wrfout_d02_%s' % time.strftime('%Y-%m-%d_%H:%M:%S')
        print (filename)

        with Dataset(rdir + filename) as nc:
            xlon = getvar(nc, 'lon')

            u        = getvar(nc, 'ua').values
            v        = getvar(nc, 'va').values
            w        = getvar(nc, 'wa').values
            ph       = getvar(nc, 'PH').values
            phb      = getvar(nc,'PHB').values
            hgt      = getvar(nc,'HGT').values

            if is_mid:
                # 'T' plus 300 is used as potential temperature (K).
                th   = getvar(nc, 'T').values + 300.

        # Average vertically adjacent PH and PHB levels and divide by g to get
        # one height (m) per wind level.
        zht = ((ph[1:]+ph[0:-1])/2. +(phb[1:]+phb[0:-1])/2.)/9.81

        if is_mid:
            with Dataset(rdir2 + filename) as nc:
                tke      = getvar(nc, 'TKE').values
                m11      = getvar(nc, 'm11').values
                m22      = getvar(nc, 'm22').values
                m33      = getvar(nc, 'm33').values

            with Dataset(rdir3 + filename) as nc:
                farea   = getvar(nc, 'FIRE_AREA').values
                fxlon   = getvar(nc, 'FXLONG').values

            # Record where the zonal wind peaks at the analysis time.
            loc.append(np.unravel_index(np.argmax(u), np.shape(u)))

            # Largest longitude among fire-grid cells whose FIRE_AREA is not
            # exactly zero; fall back to a fixed longitude when nothing burns.
            if (np.nanmax(farea) > 0.0):
                xlon_fire = np.nanmax(np.where(farea == 0.0, np.nan, fxlon))
            else:
                xlon_fire = -105.230189

        # Longitudes along the cross-section row (2-D copy for contouring).
        xlonv = np.meshgrid(xlon[x, :], range(u.shape[0]))[0]
        xlonv2 = xlon[x, :]

        # Heights are used as computed (no terrain subtraction), consistent
        # with the 'Height ASL (m)' axis label below.
        zv    = zht[:,rows,:]
        hgtv  = hgt[rows, :]

        # Interpolate the winds to the fixed heights and keep the first row.
        uwnd2 = interplevel(u[:,rows,:],zv,zinterp,meta=False)[:,0,:]
        vwnd2 = interplevel(v[:,rows,:],zv,zinterp,meta=False)[:,0,:]
        wwnd2 = interplevel(w[:,rows,:],zv,zinterp,meta=False)[:,0,:]

        if ww == 0:
            # Allocate (height, longitude, time) stacks once the interpolated
            # shape is known.
            stack_shape = [np.shape(uwnd2)[0],np.shape(uwnd2)[1],len(time_arr)]
            uwnd2_arr = np.empty(stack_shape)
            vwnd2_arr = np.empty(stack_shape)
            wwnd2_arr = np.empty(stack_shape)

        uwnd2_arr[:,:,ww] = uwnd2
        vwnd2_arr[:,:,ww] = vwnd2
        wwnd2_arr[:,:,ww] = wwnd2

        if is_mid:
            # Keep the analysis-time heights for the theta contours.
            zv_mid = zv

            # Sub-grid fields at the analysis time on the same fixed heights.
            tkev_mid2 = interplevel(tke[:,rows,:],zv,zinterp,meta=False)[:,0,:]
            m11_mid2 = interplevel(m11[:,rows,:],zv,zinterp,meta=False)[:,0,:]
            m22_mid2 = interplevel(m22[:,rows,:],zv,zinterp,meta=False)[:,0,:]
            m33_mid2 = interplevel(m33[:,rows,:],zv,zinterp,meta=False)[:,0,:]

            thv   = np.squeeze(th[:,x,:])

            ### COMPUTE PBLH
            # Height above terrain of the strongest vertical theta gradient
            # within the lowest 36 levels, per column.  Currently only computed
            # (the line that plotted it is disabled).
            pblh = np.empty(np.shape(thv)[1])
            for jj in np.arange(np.shape(thv)[1]):
                thv_grad = (thv[1:36,jj] - thv[0:35,jj]) / (zv[1:36,0,jj] - zv[0:35,0,jj])
                thv_grad_max = np.argmax(thv_grad)
                pblh[jj] = zv[thv_grad_max,0,jj] - hgtv[0,jj]

    # Remove the linear trend over the time window at every (height, lon).
    uwnd2_arr = signal.detrend(uwnd2_arr,axis=2)
    vwnd2_arr = signal.detrend(vwnd2_arr,axis=2)
    wwnd2_arr = signal.detrend(wwnd2_arr,axis=2)

    uwnd2_arr_mean = np.mean(uwnd2_arr,axis=2)
    vwnd2_arr_mean = np.mean(vwnd2_arr,axis=2)
    wwnd2_arr_mean = np.mean(wwnd2_arr,axis=2)

    # Perturbation at the analysis time (same minute as the sub-grid fields).
    uwnd2_pert = uwnd2_arr[:,:,midpt] - uwnd2_arr_mean
    vwnd2_pert = vwnd2_arr[:,:,midpt] - vwnd2_arr_mean
    wwnd2_pert = wwnd2_arr[:,:,midpt] - wwnd2_arr_mean

    tke_res = 0.5 * (uwnd2_pert*uwnd2_pert + vwnd2_pert*vwnd2_pert + wwnd2_pert*wwnd2_pert)
    # NOTE(review): confirm that summing m11, m22, m33 and TKE is the intended
    # sub-grid estimate; the formula is kept exactly as it was.
    tke_sgs = m11_mid2 + m22_mid2 + m33_mid2 + tkev_mid2
    tke_sgs[tke_sgs<0.0] = 0.0
    tke_tot = tke_sgs + tke_res
    # Zero out absurdly large values (presumably interpolation fill values
    # below the lowest model level -- TODO confirm).
    tke_tot[tke_tot>1e10] = 0.0

    # The figure is created once, for the first panel.
    if tidx == 0:
        fig = plt.figure(figsize=(12, 9))

    axs = plt.subplot(3, 2, tidx+1)

    orig_cmap = cmo.thermal
    levs = [0,5,10,15,20,25,30,35,40,45,50,60,70,80,90,100,125,150,175,200,250,300]
    norm = matplotlib.colors.BoundaryNorm(levs, orig_cmap.N)
    c0 = axs.contourf(xlonv2, zinterp, tke_tot, cmap=orig_cmap, levels=levs, extend='max', norm=norm, spacing='uniform')
    # Potential temperature every 2 K, labelled on every other contour.
    c00 = axs.contour(xlonv, zv_mid[:,0,:], thv, np.arange(0,502,2), linewidths=1., colors='limegreen')
    # Dashed line: longitude derived from FIRE_AREA; dotted line: xlon_fr.
    axs.plot([xlon_fire,xlon_fire],[1500,6000],lw=2.,ls='--',c='magenta')
    axs.plot([xlon_fr[tidx],xlon_fr[tidx]],[1500,6000],lw=2.,ls=':',c='magenta')
    axs.clabel(c00,c00.levels[::2],fmt='%d')

    axs.set_title(time_tit.strftime('%Y-%m-%d %H:%M:%SZ'))
    # Shade everything below the lowest model level in gray.
    axs.fill_between(xlon[x, :], 0, zv[0,0], facecolor="gray")

    axs.set_ylim(1500, 6000)
    # Right-hand column panels: keep the ticks but hide their labels.
    if tidx == 1 or tidx == 3:
        axs.set_yticks([2000,3000,4000,5000,6000])
        axs.set_yticklabels(['','','','',''])
    # Last panel of each column gets the x label.
    if tidx == 3 or tidx == 4:
        axs.set_xlabel('Longitude')
    # Left-hand column panels get the y label.
    if tidx == 0 or tidx == 2 or tidx == 4:
        axs.set_ylabel('Height ASL (m)')

# Finalise the multi-panel figure: tighten the layout first, then add a shared
# horizontal colourbar at a fixed figure position (lower right, where the
# 3x2 grid has no sixth panel) and write the image to disk.
plt.tight_layout()

cbar_ax = fig.add_axes([0.565, 0.2, 0.4, 0.05])
# c0 is the filled-contour set of the last panel; all panels share the same
# levels, colormap and norm, so it represents every panel.
cbar = fig.colorbar(c0, cax=cbar_ax, orientation='horizontal', spacing='uniform')
cbar.ax.tick_params(labelsize=14)
cbar.set_ticks([0,25,50,100,150,200,300])
cbar.ax.set_xlabel('Total TKE (m$^{2}$ s$^{-2}$)',fontsize=20,labelpad=20)

# NOTE: `time` is the value left behind by the processing loop, i.e. the
# timestamp of the last file that was read.
# '%Y' replaces the former '%4Y': the width prefix is a glibc-only strftime
# extension (it produces the same text there) and is not portable.
plt.savefig('xs_lat_tke_res_dt' + str(dt) + 'min_panel_%s' % time.strftime('%Y-%m-%d_%H:%M:%S'), dpi=600, bbox_inches='tight')
plt.close()

# ## For making the GIF File

# In[99]:


from PIL import Image

# NOTE(review): the script stops here, so the GIF assembly below is currently
# disabled.  Remove this call to re-enable it.
sys.exit()

# filepaths
fp_in = "./xs/*.png"
fp_out = "./xs/xs_xx.gif"

# https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#gif
all_files = sorted(glob.glob(fp_in))

# Load every PNG in name order.  Each file is opened in a context manager and
# copied, so the file handle is released straight away instead of staying open
# for the rest of the run.
frames = []
for png_path in all_files:
    with Image.open(png_path) as frame:
        frames.append(frame.copy())

if frames:
    # First frame is the base image; the rest are appended.  300 ms per frame,
    # looping forever.
    frames[0].save(fp=fp_out, format='GIF', append_images=frames[1:],
             save_all=True, duration=300, loop=0)
else:
    # Nothing matched the pattern: report it instead of failing on frames[0].
    print ('No PNG frames matched %s; GIF not written' % fp_in)


# In[ ]:




